Give basic insights into your numeric variable you have picked as output variable using one categorical variable you selected.
This uses the group_by and summarize functions. Four variables are chosen: score, rank, members, genre.
library('tidyverse')
## -- Attaching packages --------------------------------------- tidyverse 1.3.0 --
## v ggplot2 3.3.3 v purrr 0.3.4
## v tibble 3.0.5 v dplyr 1.0.3
## v tidyr 1.1.2 v stringr 1.4.0
## v readr 1.4.0 v forcats 0.5.0
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
# Load the anime dataset and drop every row containing an NA.
# FIX: the original piped into na.omit(anime); `anime` did not exist yet and
# was only harmless because of lazy evaluation -- the stray argument is removed.
Anime <- read.csv('data/anime.csv') %>%
na.omit()
# Keep only the six genres selected for this analysis.
anime <- subset(Anime,
genre %in% c('Action', 'Drama', 'Space', 'Comedy', 'Supernatural', 'Fantasy'))
rmarkdown::paged_table(anime)
# Overall range of the output variable (score) in the filtered data.
# summarize() only touches the columns it is given, so the select() step
# of the original is unnecessary; the resulting data frame is identical.
summarize(anime,
          score_max = max(score),
          score_min = min(score))
# Median / max / min of score within each selected genre.
anime %>%
  group_by(genre) %>%
  summarise(score_median = median(score),
            score_max = max(score),
            score_min = min(score))
Visualize the variables you selected.
# Distribution of score (histogram, default 30 bins).
# Mapping the aesthetic inside geom_histogram() is equivalent to mapping
# it in ggplot() for a single-layer plot.
ggplot(anime) +
  geom_histogram(aes(x = score), colour = 'white')
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Score distribution per genre, shown as horizontal boxplots.
ggplot(anime, aes(y = genre, x = score, colour = genre)) +
  geom_boxplot()
# Score against rank, coloured by genre, to inspect their relationship.
ggplot(anime, aes(x = rank, y = score, colour = genre)) +
  geom_point()
The plot does indicate a relationship between score and rank. There is a linear region in the data from approximately ranks 1000-8000. The other regions exhibit a more exponential relationship. There are several outliers in the data, which can be completely removed by filtering using scored_by > 1000. These scores are outliers as they have not been scored by enough people to be statistically significant hence why the rank of the anime is so low.
Using the full dataset, fit a regression:
# Simple linear regression of score on rank, using the full dataset.
fit1 <- lm(score ~ rank, data = Anime)
fit1  # print the estimated coefficients
##
## Call:
## lm(formula = score ~ rank, data = Anime)
##
## Coefficients:
## (Intercept) rank
## 8.1772099 -0.0003044
# Full summary: residual distribution, standard errors, t-values, R-squared.
summary(fit1)
##
## Call:
## lm(formula = score ~ rank, data = Anime)
##
## Residuals:
## Min 1Q Median 3Q Max
## -1.39032 -0.12739 -0.06236 0.07095 2.69402
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 8.177e+00 1.933e-03 4229.3 <2e-16 ***
## rank -3.044e-04 5.516e-07 -551.9 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.2013 on 26904 degrees of freedom
## Multiple R-squared: 0.9188, Adjusted R-squared: 0.9188
## F-statistic: 3.046e+05 on 1 and 26904 DF, p-value: < 2.2e-16
# Overlay the fitted regression line (red) on the score~rank scatterplot.
preds <- predict(fit1)
ggplot(Anime, aes(x = rank, y = score)) +
  geom_point(size = 2, alpha = .3) +
  geom_line(aes(y = preds), colour = "red")
The low residual standard error (0.2013) and high R-squared value (0.9188) demonstrate that this is a good fit. Also, the Pr(>|t|) value is very small, indicating that the current feature (rank) has a statistically significant impact on the target variable. However, the estimated coefficient is very small, meaning a one-unit change in rank does not actually influence the target variable that much.
library('ggExtra')
# Log-log scatterplot of score vs. rank, coloured by genre;
# ggMarginal() adds marginal histograms along both axes.
g <- ggplot(Anime, aes(y=log(score), x=log(rank), colour=genre)) +
geom_point(alpha=.3) +
theme(legend.position = 'none')
ggMarginal(g, type='histogram')
2. Using all your input variables, fit a multiple linear regression model
# Multiple linear regression: rank, members and genre as predictors.
fit2 <- lm(score ~ rank + members + genre, data = Anime)
summary(fit2)$coefficients  # coefficient table only (estimates, SE, t, p)
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 8.064790e+00 4.598906e-03 1753.63235597 0.000000e+00
## rank -2.892395e-04 5.917560e-07 -488.78166826 0.000000e+00
## members 3.206731e-07 5.921933e-09 54.15007345 0.000000e+00
## genreAdventure 2.206498e-02 6.869333e-03 3.21209886 1.319243e-03
## genreCars 1.502186e-01 7.152279e-02 2.10029009 3.571263e-02
## genreComedy 1.866423e-02 5.113432e-03 3.65003872 2.626977e-04
## genreDementia -2.272931e-02 3.051278e-02 -0.74491114 4.563320e-01
## genreDemons -2.046622e-02 1.180110e-02 -1.73426386 8.288275e-02
## genreDrama 3.285056e-02 6.160015e-03 5.33287021 9.745312e-08
## genreEcchi -3.559035e-02 7.858602e-03 -4.52883928 5.956344e-06
## genreFantasy 2.303485e-03 6.184754e-03 0.37244567 7.095640e-01
## genreGame -6.117883e-04 1.421349e-02 -0.04304279 9.656677e-01
## genreHarem -6.487765e-02 9.101932e-03 -7.12789819 1.044856e-12
## genreHistorical 8.651944e-02 1.053504e-02 8.21254247 2.261832e-16
## genreHorror -3.499210e-02 1.330241e-02 -2.63050888 8.530578e-03
## genreJosei 9.962592e-02 1.762107e-02 5.65379409 1.585356e-08
## genreKids 2.221694e-01 1.832222e-02 12.12568046 9.461180e-34
## genreMagic -2.367664e-02 8.401793e-03 -2.81804641 4.835212e-03
## genreMartial Arts -1.866746e-02 1.708683e-02 -1.09250554 2.746208e-01
## genreMecha 8.422511e-03 9.195293e-03 0.91595901 3.596966e-01
## genreMilitary -2.024578e-02 1.081085e-02 -1.87272719 6.111692e-02
## genreMusic 9.453884e-03 1.171731e-02 0.80683058 4.197712e-01
## genreMystery 1.136020e-02 8.476921e-03 1.34013289 1.802135e-01
## genreParody 1.168169e-01 1.340795e-02 8.71250923 3.140022e-18
## genrePolice 4.638401e-02 1.998000e-02 2.32152209 2.026613e-02
## genrePsychological 1.122076e-02 1.155004e-02 0.97149116 3.313125e-01
## genreRomance -2.480369e-02 5.954709e-03 -4.16539033 3.117906e-05
## genreSamurai 2.510075e-01 1.939550e-02 12.94153661 3.412664e-38
## genreSchool -2.438930e-03 5.941759e-03 -0.41047276 6.814625e-01
## genreSci-Fi 3.254618e-02 6.510465e-03 4.99905639 5.797370e-07
## genreSeinen 8.350385e-03 8.430738e-03 0.99046897 3.219539e-01
## genreShoujo 2.568592e-02 1.050073e-02 2.44610805 1.444714e-02
## genreShoujo Ai -2.903904e-02 1.911432e-02 -1.51922933 1.287166e-01
## genreShounen 4.890617e-02 6.507456e-03 7.51540514 5.850305e-14
## genreShounen Ai 8.077249e-03 2.294008e-02 0.35210213 7.247644e-01
## genreSlice of Life 3.310655e-02 6.593456e-03 5.02112264 5.169880e-07
## genreSpace 6.249208e-03 1.302150e-02 0.47991467 6.312920e-01
## genreSports 8.314489e-02 1.107374e-02 7.50829680 6.176132e-14
## genreSuper Power -6.980974e-04 9.835308e-03 -0.07097871 9.434152e-01
## genreSupernatural 1.210736e-02 6.614702e-03 1.83037140 6.720550e-02
## genreThriller 3.854767e-02 1.781419e-02 2.16387470 3.048275e-02
## genreVampire -2.754041e-02 1.740662e-02 -1.58217984 1.136203e-01
# Compare residual standard errors of the two fits (lower is better).
sigma(fit1)
## [1] 0.2013438
sigma(fit2)
## [1] 0.1889275
The RMSE of this fit is slightly lower than that of fit1, which indicates it is a better fit. The p-values for the coefficients of rank and members are much smaller than those of most other variables. This means rank and members seem to have more significant impacts on the score.
# Add the members:genre interaction to the previous model.
fit3 <- lm(score ~ rank + members + genre + members:genre, data = Anime)
summary(fit3)$coefficients
## Estimate Std. Error t value
## (Intercept) 8.071129e+00 5.368449e-03 1503.43776964
## rank -2.890416e-04 5.891624e-07 -490.59748201
## members 2.864419e-07 1.540740e-08 18.59118607
## genreAdventure 2.537913e-02 8.369535e-03 3.03232261
## genreCars 1.324192e-01 8.590812e-02 1.54140448
## genreComedy 1.828893e-02 6.512305e-03 2.80836473
## genreDementia -1.091502e-01 4.852563e-02 -2.24933101
## genreDemons 2.575254e-02 1.635465e-02 1.57463105
## genreDrama 8.207269e-03 7.790391e-03 1.05351184
## genreEcchi -7.675581e-03 1.124357e-02 -0.68266436
## genreFantasy 1.090625e-02 7.821238e-03 1.39443993
## genreGame 5.876095e-02 1.717118e-02 3.42206810
## genreHarem 2.542673e-02 1.633557e-02 1.55652574
## genreHistorical 4.704492e-03 1.368841e-02 0.34368423
## genreHorror -1.318121e-03 1.672278e-02 -0.07882190
## genreJosei -6.541574e-02 3.241588e-02 -2.01801545
## genreKids 2.718596e-01 2.003507e-02 13.56918171
## genreMagic -2.538177e-02 1.068231e-02 -2.37605607
## genreMartial Arts -3.075578e-02 2.081436e-02 -1.47762344
## genreMecha -2.415542e-02 1.064827e-02 -2.26848348
## genreMilitary -6.202107e-02 1.293799e-02 -4.79371616
## genreMusic -4.780260e-02 1.400225e-02 -3.41392270
## genreMystery 4.471270e-03 1.081293e-02 0.41351136
## genreParody 6.063562e-02 1.647834e-02 3.67971663
## genrePolice 3.001724e-02 2.242003e-02 1.33885788
## genrePsychological 1.665750e-02 1.571931e-02 1.05968437
## genreRomance 2.873430e-03 8.135822e-03 0.35318252
## genreSamurai 6.482235e-02 2.616395e-02 2.47754490
## genreSchool -4.825457e-03 8.079733e-03 -0.59722973
## genreSci-Fi -1.461046e-03 7.970907e-03 -0.18329729
## genreSeinen -1.026821e-02 1.070213e-02 -0.95945475
## genreShoujo 2.198074e-02 1.413246e-02 1.55533742
## genreShoujo Ai 7.602117e-03 3.533480e-02 0.21514532
## genreShounen 3.731012e-02 8.261073e-03 4.51637698
## genreShounen Ai -7.844700e-02 4.466534e-02 -1.75632825
## genreSlice of Life 1.983797e-02 8.618769e-03 2.30171753
## genreSpace -1.335239e-02 1.662854e-02 -0.80298038
## genreSports 3.412500e-02 1.418641e-02 2.40547138
## genreSuper Power -5.800922e-02 1.293030e-02 -4.48629954
## genreSupernatural 1.727032e-02 8.645928e-03 1.99750879
## genreThriller 1.276377e-02 2.637383e-02 0.48395587
## genreVampire 8.725701e-03 2.845791e-02 0.30661775
## members:genreAdventure -2.616619e-08 2.741164e-08 -0.95456506
## members:genreCars 4.483061e-07 1.927153e-06 0.23262610
## members:genreComedy -5.338724e-09 2.273723e-08 -0.23480094
## members:genreDementia 2.913668e-07 1.237870e-07 2.35377537
## members:genreDemons -2.810710e-07 6.764222e-08 -4.15525938
## members:genreDrama 1.344607e-07 2.617179e-08 5.13761915
## members:genreEcchi -1.440384e-07 4.127439e-08 -3.48977601
## members:genreFantasy -5.112192e-08 2.595372e-08 -1.96973361
## members:genreGame -2.722581e-07 4.625332e-08 -5.88623863
## members:genreHarem -4.224639e-07 6.424262e-08 -6.57606889
## members:genreHistorical 6.463491e-07 7.023310e-08 9.20291274
## members:genreHorror -1.740007e-07 5.287412e-08 -3.29084790
## members:genreJosei 1.375233e-06 2.311319e-07 5.94999267
## members:genreKids -1.891067e-06 2.746281e-07 -6.88591886
## members:genreMagic 3.740772e-09 3.837899e-08 0.09746927
## members:genreMartial Arts 7.164296e-08 8.691063e-08 0.82432912
## members:genreMecha 2.638859e-07 4.486967e-08 5.88116404
## members:genreMilitary 2.203933e-07 3.863224e-08 5.70490707
## members:genreMusic 5.817949e-07 8.121096e-08 7.16399416
## members:genreMystery 3.453200e-08 2.953081e-08 1.16935486
## members:genreParody 2.956041e-07 5.177035e-08 5.70991074
## members:genrePolice 7.912177e-08 5.183963e-08 1.52627967
## members:genrePsychological -2.746345e-09 3.499037e-08 -0.07848859
## members:genreRomance -1.486106e-07 2.928550e-08 -5.07454576
## members:genreSamurai 1.206839e-06 1.160449e-07 10.39975271
## members:genreSchool 1.227654e-08 2.732117e-08 0.44934162
## members:genreSci-Fi 2.246829e-07 2.965438e-08 7.57672171
## members:genreSeinen 1.060609e-07 3.956686e-08 2.68054917
## members:genreShoujo 1.043617e-08 7.244189e-08 0.14406271
## members:genreShoujo Ai -3.985545e-07 2.953864e-07 -1.34926490
## members:genreShounen 5.663212e-08 2.461543e-08 2.30067532
## members:genreShounen Ai 1.218405e-06 5.703170e-07 2.13636476
## members:genreSlice of Life 8.026611e-08 3.737330e-08 2.14768600
## members:genreSpace 2.284411e-07 1.564424e-07 1.46022496
## members:genreSports 4.354370e-07 8.291270e-08 5.25175248
## members:genreSuper Power 2.029871e-07 3.073352e-08 6.60474607
## members:genreSupernatural -1.578621e-08 2.510199e-08 -0.62888284
## members:genreThriller 7.367298e-08 4.289820e-08 1.71739116
## members:genreVampire -1.773181e-07 1.112775e-07 -1.59347715
## Pr(>|t|)
## (Intercept) 0.000000e+00
## rank 0.000000e+00
## members 1.149444e-76
## genreAdventure 2.429114e-03
## genreCars 1.232302e-01
## genreComedy 4.982977e-03
## genreDementia 2.449952e-02
## genreDemons 1.153534e-01
## genreDrama 2.921160e-01
## genreEcchi 4.948249e-01
## genreFantasy 1.631963e-01
## genreGame 6.223945e-04
## genreHarem 1.195949e-01
## genreHistorical 7.310865e-01
## genreHorror 9.371749e-01
## genreJosei 4.359959e-02
## genreKids 8.382920e-42
## genreMagic 1.750578e-02
## genreMartial Arts 1.395203e-01
## genreMecha 2.330767e-02
## genreMilitary 1.645964e-06
## genreMusic 6.412942e-04
## genreMystery 6.792353e-01
## genreParody 2.339501e-04
## genrePolice 1.806283e-01
## genrePsychological 2.892978e-01
## genreRomance 7.239544e-01
## genreSamurai 1.323509e-02
## genreSchool 5.503591e-01
## genreSci-Fi 8.545661e-01
## genreSeinen 3.373384e-01
## genreShoujo 1.198775e-01
## genreShoujo Ai 8.296557e-01
## genreShounen 6.317450e-06
## genreShounen Ai 7.904379e-02
## genreSlice of Life 2.135873e-02
## genreSpace 4.219932e-01
## genreSports 1.615834e-02
## genreSuper Power 7.277162e-06
## genreSupernatural 4.578002e-02
## genreThriller 6.284211e-01
## genreVampire 7.591367e-01
## members:genreAdventure 3.398063e-01
## members:genreCars 8.160535e-01
## members:genreComedy 8.143650e-01
## members:genreDementia 1.859101e-02
## members:genreDemons 3.259251e-05
## members:genreDrama 2.801881e-07
## members:genreEcchi 4.842012e-04
## members:genreFantasy 4.887918e-02
## members:genreGame 3.997798e-09
## members:genreHarem 4.919245e-11
## members:genreHistorical 3.729927e-20
## members:genreHorror 1.000148e-03
## members:genreJosei 2.714792e-09
## members:genreKids 5.868132e-12
## members:genreMagic 9.223545e-01
## members:genreMartial Arts 4.097599e-01
## members:genreMecha 4.122172e-09
## members:genreMilitary 1.176309e-08
## members:genreMusic 8.038077e-13
## members:genreMystery 2.422711e-01
## members:genreParody 1.142284e-08
## members:genrePolice 1.269520e-01
## members:genrePsychological 9.374400e-01
## members:genreRomance 3.910139e-07
## members:genreSamurai 2.776835e-25
## members:genreSchool 6.531889e-01
## members:genreSci-Fi 3.658212e-14
## members:genreSeinen 7.354638e-03
## members:genreShoujo 8.854520e-01
## members:genreShoujo Ai 1.772633e-01
## members:genreShounen 2.141762e-02
## members:genreShounen Ai 3.265872e-02
## members:genreSlice of Life 3.174764e-02
## members:genreSpace 1.442400e-01
## members:genreSports 1.518069e-07
## members:genreSuper Power 4.056427e-11
## members:genreSupernatural 5.294311e-01
## members:genreThriller 8.591927e-02
## members:genreVampire 1.110650e-01
# Residual standard error of the interaction model.
sigma(fit3)
## [1] 0.1862152
Adding more variables decreased the RMSE value, demonstrating a better fit than fit2. Many variables have relatively small p-values. Meanwhile, it is noticeable that the estimated coefficients of genre alone tend to be larger in magnitude than those of the interaction terms, members:genre. This suggests the interaction between members and genre may not influence the score very much.
In this section, you will do the same you did in 1.3, but this time you will first split the data into train and test.
# Fix the RNG seed so the train/test split is reproducible.
# FIX: the original line was garbled ("set.seed(...).set.seed(156)"),
# which is a syntax error in R.
set.seed(156)
# 80/20 split by row index; seq_len() is the safe form of 1:nrow().
train_size <- floor(0.8 * nrow(Anime))
train_inds <- sample(seq_len(nrow(Anime)), size = train_size)
test_inds <- setdiff(seq_len(nrow(Anime)), train_inds)
train <- Anime[train_inds, ]
test <- Anime[test_inds, ]
cat('train size:', nrow(train), '\ntest size:', nrow(test))
## train size: 21524
## test size: 5382
library('caret')
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
##
## lift
# Refit the three models on the training set only.
fit1 <- lm(score ~ rank, data = train)
fit2 <- lm(score ~ rank + members + genre, data = train)
fit3 <- lm(score ~ rank + members + genre + members:genre, data = train)
# Out-of-sample predictions on the held-out test set.
pred1 <- predict(fit1, newdata=test)
pred2 <- predict(fit2, newdata=test)
pred3 <- predict(fit3, newdata=test)
# Test-set RMSE for each model (caret::RMSE).
rmse1 <- RMSE(pred1, test$score)
rmse2 <- RMSE(pred2, test$score)
rmse3 <- RMSE(pred3, test$score)
rmses <- c(rmse1,rmse2,rmse3)
rmses
## [1] 0.2013877 0.1891842 0.1869903
The models are ranked as the following. Performance rating: fit3 > fit2 > fit1. This is the same as the one we had in 1.3.
In case you want to take a picture (screenshot) of your notebook (tablet), you can use the below lines to embed the image to the output PDF file:
library('knitr')
# Heights (cm) of the three subjects.
X <- c(185, 175, 170)
# Weights (kg) of the three subjects.
Y <- c(105, 73, 66)
column.names <- c("Height (cm)", "Weight (kg)")
row.names <- c("", "", "")
# Two-column table: heights in column 1, weights in column 2.
# matrix(..., nrow = 3) builds exactly the same object as
# array(..., dim = c(3, 2)).
result <- matrix(c(X, Y), nrow = 3, dimnames = list(row.names, column.names))
kable(result, caption = 'Simple Dataset')
| Height (cm) | Weight (kg) | |
|---|---|---|
| 185 | 105 | |
| 175 | 73 | |
| 170 | 66 |
Hypothesis function \(h_{\theta}(x) = {\theta_0} +{\theta_1}x\) Let \({\theta_0}\) = 0 \({\theta_1}\) = 0 \(\alpha\) = 0.0001
# Embed the photographed hand-worked gradient-descent example in the output.
knitr::include_graphics('gd.jpg')
You will use horsepower as input variable and miles per gallon (mpg) as output:
mpg (\(Y\)) and horsepower (\(X\)).
library('ISLR')
# Raw-scale scatterplot: mpg against horsepower from the Auto dataset.
ggplot(Auto, aes(x = horsepower, y = mpg)) +
  geom_point()
The relationship is negative: as horsepower increases, mpg decreases. The relationship is non-linear and looks roughly logarithmic. 2. Plot the scatterplot between
log(mpg) and log(horsepower). - Is the relationship positive or negative? - Is the relationship linear?
# The same relationship on the log-log scale.
ggplot(Auto, aes(x = log(horsepower), y = log(mpg))) +
  geom_point()
The relationship is negative, and linear, which makes sense after applying the log() function on the data above which seemed to have a logarithmic relationship.
3. Which of the two versions is better for linear regression? The log data is better for linear regression
The code below estimates the coefficients of linear regression using gradient descent algorithm. If you are given a single linear regression model;
\[Y = \beta_0 + \beta_1 X \]
where \(Y=[Y_1,\dots,Y_N]^T\) and \(X=[X_1,\dots,X_N]^T\) are output and input vectors containing the observations.
The algorithm estimates the parameter vector \(\theta = [\beta_0,\beta_1]\) by starting with an arbitrary \(\theta_0\) and adjusting it with the gradient of the loss function as:
\[\theta := \theta + \frac{\alpha}{N} X^T(Y - X\theta)\]
where \(\alpha\) is the step size (or learning rate) and \(X^T(Y - X\theta)\) is (up to the \(1/N\) factor) the negative gradient of the squared-error loss. At each step the algorithm evaluates this quantity and adjusts the parameter vector accordingly.
GDA <- function(x, y, theta0, alpha = 0.01, epsilon = 1e-8, max_iter=25000){
  # Gradient descent for linear regression (vectorised full-batch updates).
  #
  # Inputs
  #   x        : the input variables (M columns; vector or matrix)
  #   y        : output variable (1 column)
  #   theta0   : initial weight vector (M+1 entries, intercept first)
  #   alpha    : learning rate (step size)
  #   epsilon  : stop once the absolute cost improvement drops below this
  #   max_iter : hard cap on the number of iterations
  # Returns the estimated parameter vector theta.
  # Stops with an error if the cost ever increases (alpha too large).
  x <- as.matrix(x)
  y <- as.matrix(y)
  N <- nrow(x)
  i <- 0
  theta <- theta0
  x <- cbind(1, x) # Adding 1 as first column for intercept
  imprv <- 1e10
  # Cost J(theta) = (1/2N) ||X theta - y||^2; element 1 is the initial cost.
  cost <- (1/(2*N)) * t(x %*% theta - y) %*% (x %*% theta - y)
  # FIX: use the scalar short-circuit operator `&&` in the while condition
  # (the original used the elementwise `&`); also drop the unused `delta`.
  while(imprv > epsilon && i < max_iter){
    i <- i + 1
    grad <- (t(x) %*% (y - x %*% theta))  # X^T (y - X theta)
    theta <- theta + (alpha / N) * grad   # descent step
    cost <- append(cost, (1/(2*N)) * t(x %*% theta - y) %*% (x %*% theta - y))
    imprv <- abs(cost[i+1] - cost[i])
    # A cost increase means the step size is too large for this problem.
    if((cost[i+1] - cost[i]) > 0) stop("Cost is increasing. Try reducing alpha.")
  }
  if (i == max_iter){
    # FIX: typo "interation" -> "iteration" in the status message.
    print(paste0("maximum iteration ", max_iter, " was reached"))
  } else {
    print(paste0("Finished in ", i, " iterations"))
  }
  return(theta)
}
# Helper: log-log scatterplot of the Auto data with the line implied by
# theta (intercept = theta[1], slope = theta[2]) and the values in the title.
plot_line <- function(theta) {
ggplot(Auto, aes(x=log(horsepower),y=log(mpg))) +
geom_point(alpha=.7) +
geom_abline(slope = theta[2], intercept = theta[1], colour='firebrick') +
ggtitle(paste0('int: ', round(theta[1],2), ', slope: ', round(theta[2],2)))
}
# Fit log(mpg) ~ log(horsepower) by gradient descent.
x <- log(Auto$horsepower)
y <- log(Auto$mpg)
theta0 <- c(1,1)
theta <- GDA(x, y, theta0, alpha = 0.05, epsilon = 1e-5)
## [1] "Finished in 3193 iterations"
plot_line(theta)
It took 3193 iterations to converge on the parameters.
1e-6, set alpha=0.05 run the code.
# Redefinition of the plotting helper -- identical to the earlier one
# (each chunk was written to be runnable on its own).
plot_line <- function(theta) {
ggplot(Auto, aes(x=log(horsepower),y=log(mpg))) +
geom_point(alpha=.7) +
geom_abline(slope = theta[2], intercept = theta[1], colour='firebrick') +
ggtitle(paste0('int: ', round(theta[1],2), ', slope: ', round(theta[2],2)))
}
x <- log(Auto$horsepower)
y <- log(Auto$mpg)
theta0 <- c(1,1)
# Same alpha as before; only the stopping tolerance is tightened (1e-5 -> 1e-6).
theta <- GDA(x, y, theta0, alpha = 0.05, epsilon = 1e-6)
## [1] "Finished in 7531 iterations"
plot_line(theta)
After tightening epsilon (the learning rate alpha was unchanged at 0.05), it took 7531 iterations for the algorithm to converge. This is expected: a smaller epsilon means the loop requires a smaller cost improvement before stopping, so it runs longer and the resulting parameter estimates are refined further.
3. Reduce alpha to alpha=0.01 - How many iterations did it take? - Did the resulting line change? Why or why not?
# Redefinition of the plotting helper -- identical to the earlier one.
plot_line <- function(theta) {
ggplot(Auto, aes(x=log(horsepower),y=log(mpg))) +
geom_point(alpha=.7) +
geom_abline(slope = theta[2], intercept = theta[1], colour='firebrick') +
ggtitle(paste0('int: ', round(theta[1],2), ', slope: ', round(theta[2],2)))
}
x <- log(Auto$horsepower)
y <- log(Auto$mpg)
theta0 <- c(1,1)
# Learning rate reduced to 0.01 (smaller steps -> more iterations needed).
theta <- GDA(x, y, theta0, alpha = 0.01, epsilon = 1e-6)
## [1] "Finished in 22490 iterations"
plot_line(theta)
22490 iterations The line changes slightly
alpha=0.05 and try theta0=c(1,1) vs. theta0=c(1,-1):
# Redefinition of the plotting helper -- identical to the earlier one.
plot_line <- function(theta) {
ggplot(Auto, aes(x=log(horsepower),y=log(mpg))) +
geom_point(alpha=.7) +
geom_abline(slope = theta[2], intercept = theta[1], colour='firebrick') +
ggtitle(paste0('int: ', round(theta[1],2), ', slope: ', round(theta[2],2)))
}
x <- log(Auto$horsepower)
y <- log(Auto$mpg)
# Compare two starting points with the same alpha and epsilon.
theta0 <- c(1,1)
theta <- GDA(x, y, theta0, alpha = 0.05, epsilon = 1e-6)
## [1] "Finished in 7531 iterations"
plot_line(theta)
# A negative initial slope is closer to the true (negative) relationship.
theta0 <- c(1,-1)
theta <- GDA(x, y, theta0, alpha = 0.05, epsilon = 1e-6)
## [1] "Finished in 7265 iterations"
plot_line(theta)
Changing the initial slope to be negative results in 7265 iterations which is less than the positive one which took 7531 iterations. This is because our initial guess is now closer to the actual data and should converge quicker.
epsilon = 1e-8 and try alpha=0.01, alpha=0.05 and alpha=0.1.
# Redefinition of the plotting helper -- identical to the earlier one.
plot_line <- function(theta) {
ggplot(Auto, aes(x=log(horsepower),y=log(mpg))) +
geom_point(alpha=.7) +
geom_abline(slope = theta[2], intercept = theta[1], colour='firebrick') +
ggtitle(paste0('int: ', round(theta[1],2), ', slope: ', round(theta[2],2)))
}
x <- log(Auto$horsepower)
y <- log(Auto$mpg)
theta0 <- c(1,1)
# The alpha = 0.01 and alpha = 0.1 runs are left commented out;
# only the alpha = 0.05 run with epsilon = 1e-8 is executed here.
#theta <- GDA(x, y, theta0, alpha = 0.01, epsilon = 1e-8)
#plot_line(theta)
#theta1 <- GDA(x, y, theta0, alpha = 0.1, epsilon = 1e-8)
#plot_line(theta1)
theta2 <- GDA(x, y, theta0, alpha = 0.05, epsilon = 1e-8)
## [1] "Finished in 16207 iterations"
plot_line(theta2)
When running on a high alpha value, it can result in divergence where we will hit the maximum iterations. However, when it is too low in case alpha = 0.01, it can also not converge fast enough where the result will be sub-optimal after the maximum iterations, this is seen in the graph with the line not as accurately representing the data compared to alpha = 0.05
BGD <- function(x, y, theta0, alpha = 0.01, epsilon = 1e-8, max_iter=25000){
  # Batch gradient descent for linear regression. Mathematically the same
  # update as GDA(), but the gradient is accumulated observation by
  # observation (for exposition) instead of with one matrix product.
  #
  # Inputs
  #   x        : the input variables (M columns)
  #   y        : output variable (1 column)
  #   theta0   : initial weight vector (M+1 entries, intercept first)
  #   alpha    : learning rate
  #   epsilon  : stop once the absolute cost improvement drops below this
  #   max_iter : hard cap on the number of iterations
  # Returns list(theta, cost) where cost is the per-iteration loss trace.
  x <- as.matrix(x)
  y <- as.matrix(y)
  N <- nrow(x)
  i <- 0
  theta <- theta0
  x <- cbind(1, x) # Adding 1 as first column for intercept
  imprv <- 1e10
  cost <- (1/(2*N)) * t(x %*% theta - y) %*% (x %*% theta - y)
  # FIX: `&&` instead of elementwise `&`; removed a stray no-op `cost`
  # expression that sat at the top of the loop body; dropped unused `delta`.
  while(imprv > epsilon && i < max_iter){
    i <- i + 1
    # Accumulate the full-batch gradient over every observation.
    grad <- 0
    for(j in seq_along(y)){  # FIX: seq_along() instead of 1:length(y)
      grad <- grad + x[j, ] * c(y[j] - x[j, ] %*% theta)
    }
    theta <- theta + (alpha / N) * grad
    cost <- append(cost, (1/(2*N)) * t(x %*% theta - y) %*% (x %*% theta - y))
    imprv <- abs(cost[i+1] - cost[i])
    if((cost[i+1] - cost[i]) > 0) stop("Cost is increasing. Try reducing alpha.")
  }
  print(paste0("Stopped in ", i, " iterations"))
  cost <- cost[-1]  # drop the initial cost so the trace aligns with iterations
  return(list(theta,cost))
}
x <- log(Auto$horsepower)
y <- log(Auto$mpg)
# Run BGD for only 10 iterations to inspect the early loss trajectory.
res <- BGD(x, y, c(1, -1), alpha = 0.005, epsilon = 1e-5, max_iter = 10)
## [1] "Stopped in 10 iterations"
theta <- res[[1]]
loss <- res[[2]]
# Plot loss per iteration.
ggplot() +
geom_point(aes(x=1:length(loss), y=loss)) +
labs(x='iteration')
SGD <- function(x, y, theta0, alpha = 0.01, epsilon = 1e-8, max_iter=25000){
  # Stochastic gradient descent for linear regression: each iteration updates
  # theta using the gradient of ONE uniformly random observation.
  #
  # Inputs
  #   x        : the input variables (M columns)
  #   y        : output variable (1 column)
  #   theta0   : initial weight vector (M+1 entries, intercept first)
  #   alpha    : learning rate
  #   epsilon  : stop once the absolute cost improvement drops below this
  #   max_iter : hard cap on the number of iterations
  # Returns list(theta, cost) where cost is the per-iteration loss trace.
  x <- as.matrix(x)
  y <- as.matrix(y)
  N <- nrow(x)
  i <- 0
  theta <- theta0
  x <- cbind(1, x) # Adding 1 as first column for intercept
  imprv <- 1e10
  cost <- (1/(2*N)) * t(x %*% theta - y) %*% (x %*% theta - y)
  # FIX: `&&` instead of elementwise `&`; removed a stray no-op `cost`
  # expression at the top of the loop body; dropped the needless
  # `grad <- 0; grad <- grad + ...` accumulator and the unused `delta`.
  while(imprv > epsilon && i < max_iter){
    i <- i + 1
    # Single-sample gradient at a uniformly random data point.
    randpt <- sample(length(y),1)
    grad <- x[randpt, ] * c(y[randpt]-x[randpt, ] %*% theta)
    # NOTE(review): the update is scaled by 1/N even though the gradient comes
    # from a single point, so SGD's steps are N times smaller than BGD's --
    # confirm whether that damping is intended.
    theta <- theta + (alpha / N) * grad
    # The stopping rule still evaluates the FULL-data cost each iteration.
    cost <- append(cost, (1/(2*N)) * t(x %*% theta - y) %*% (x %*% theta - y))
    imprv <- abs(cost[i+1] - cost[i])
    if((cost[i+1] - cost[i]) > 0) stop("Cost is increasing. Try reducing alpha.")
  }
  print(paste0("Stopped in ", i, " iterations"))
  cost <- cost[-1]  # drop the initial cost so the trace aligns with iterations
  return(list(theta,cost))
}
x <- log(Auto$horsepower)
y <- log(Auto$mpg)
# Run SGD for 10 iterations with the same settings as the BGD run above.
res <- SGD(x, y, c(1, -1), alpha = 0.005, epsilon = 1e-5, max_iter = 10)
## [1] "Stopped in 10 iterations"
theta <- res[[1]]
loss <- res[[2]]
# Plot loss per iteration for comparison with the BGD trace.
ggplot() +
geom_point(aes(x=1:length(loss), y=loss)) +
labs(x='iteration')
2. BGD resulted in a smooth, steadily decreasing loss over the 10 iterations, since every update uses the full-batch gradient; SGD resulted in a noisier loss trajectory, since each update is based on a single randomly sampled observation.